# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import abc
import re

from typing import Optional


class BenchmarkCommand(abc.ABC):
    """Abstracts a benchmark command."""

    def __init__(
        self,
        benchmark_binary: str,
        model_name: str,
        num_threads: int,
        num_runs: int,
        driver: Optional[str] = None,
        taskset: Optional[str] = None,
    ):
        self.benchmark_binary = benchmark_binary
        self.model_name = model_name
        self.taskset = taskset
        self.num_threads = num_threads
        self.num_runs = num_runs
        self.driver = driver
        self.args: list[str] = []

    @property
    @abc.abstractmethod
    def runtime(self) -> str:
        """Returns the runtime name, e.g. 'tflite' or 'iree'."""

    @abc.abstractmethod
    def parse_latency_from_output(self, output: str) -> float:
        """Parses the mean latency in milliseconds from the benchmark output."""

    def generate_benchmark_command(self) -> list[str]:
        """Returns a list of strings that correspond to the command to be run."""
        command = []
        if self.taskset:
            command.append("taskset")
            command.append(str(self.taskset))
        command.append(self.benchmark_binary)
        command.extend(self.args)
        return command
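
# For illustration (all values hypothetical): a subclass constructed with
# taskset="80" whose __init__ appended "--graph=/tmp/model.tflite" to
# self.args would generate:
#   ["taskset", "80", "/path/to/benchmark_binary", "--graph=/tmp/model.tflite"]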


class TFLiteBenchmarkCommand(BenchmarkCommand):
    """Represents a TFLite benchmark command."""

    def __init__(
        self,
        benchmark_binary: str,
        model_name: str,
        model_path: str,
        num_threads: int,
        num_runs: int,
        taskset: Optional[str] = None,
    ):
        super().__init__(
            benchmark_binary, model_name, num_threads, num_runs, taskset=taskset
        )
        self.args.append("--graph=" + model_path)
        # Scientific-notation latencies, e.g. "Inference (avg): 1.18859e+06".
        self._latency_large_regex = re.compile(
            r".*?Inference \(avg\): (\d+\.?\d*e\+?\d*).*"
        )
        # Plain decimal latencies, e.g. "Inference (avg): 71495.6".
        self._latency_regex = re.compile(r".*?Inference \(avg\): (\d+\.?\d*).*")

    @property
    def runtime(self):
        return "tflite"

    def parse_latency_from_output(self, output: str) -> float:
        # First check for a large number in scientific notation, e.g. 1.18859e+06.
        matches = self._latency_large_regex.search(output)
        if not matches:
            # Otherwise, fall back to a plain decimal, e.g. 71495.6.
            matches = self._latency_regex.search(output)

        latency_ms = 0.0
        if matches:
            # The TFLite benchmark reports latency in microseconds; convert to ms.
            latency_ms = float(matches.group(1)) / 1000
        else:
            print("Warning! Could not parse latency. Defaulting to 0 ms.")
        return latency_ms

    def generate_benchmark_command(self) -> list[str]:
        command = super().generate_benchmark_command()
        if self.driver == "gpu":
            command.append("--use_gpu=true")
        command.append("--num_threads=" + str(self.num_threads))
        command.append("--num_runs=" + str(self.num_runs))
        return command
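
# A minimal usage sketch (binary and model paths are hypothetical):
#   cmd = TFLiteBenchmarkCommand(
#       "/data/local/tmp/benchmark_model", "mobilenet",
#       "/data/local/tmp/mobilenet.tflite", num_threads=4, num_runs=50)
#   cmd.driver = "gpu"  # Optional; adds --use_gpu=true.
#   cmd.generate_benchmark_command()
#   # -> ["/data/local/tmp/benchmark_model",
#   #     "--graph=/data/local/tmp/mobilenet.tflite", "--use_gpu=true",
#   #     "--num_threads=4", "--num_runs=50"]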


class IreeBenchmarkCommand(BenchmarkCommand):
    """Represents an IREE benchmark command."""

    def __init__(
        self,
        benchmark_binary: str,
        model_name: str,
        model_path: str,
        num_threads: int,
        num_runs: int,
        taskset: Optional[str] = None,
    ):
        super().__init__(
            benchmark_binary, model_name, num_threads, num_runs, taskset=taskset
        )
        self.args.append("--module=" + model_path)
        # Mean wall-time row of the benchmark output, e.g. a line containing
        # "BM_main/process_time/real_time_mean    12.3 ms".
        self._latency_regex = re.compile(
            r".*?BM_main/process_time/real_time_mean\s+(.*?) ms.*"
        )

    @property
    def runtime(self):
        return "iree"

    def parse_latency_from_output(self, output: str) -> float:
        matches = self._latency_regex.search(output)
        latency_ms = 0.0
        if matches:
            latency_ms = float(matches.group(1))
        else:
            print("Warning! Could not parse latency. Defaulting to 0 ms.")
        return latency_ms

    def generate_benchmark_command(self) -> list[str]:
        command = super().generate_benchmark_command()
        command.append("--device=" + self.driver)
        command.append("--task_topology_max_group_count=" + str(self.num_threads))
        command.append("--benchmark_repetitions=" + str(self.num_runs))
        return command
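
# A minimal usage sketch (paths hypothetical; "local-task" is one of IREE's
# device names):
#   cmd = IreeBenchmarkCommand(
#       "/data/local/tmp/iree-benchmark-module", "mobilenet",
#       "/data/local/tmp/mobilenet.vmfb", num_threads=4, num_runs=50)
#   cmd.driver = "local-task"
#   cmd.generate_benchmark_command()
#   # -> ["/data/local/tmp/iree-benchmark-module",
#   #     "--module=/data/local/tmp/mobilenet.vmfb", "--device=local-task",
#   #     "--task_topology_max_group_count=4", "--benchmark_repetitions=50"]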
